#! /usr/bin/env python3
# -*-python-*-
#
###########################################################################
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
###########################################################################
import argparse
import itertools
import json
import os
import os.path
import random
import re
import shlex
import shutil
import subprocess
import sys

script_dir = os.path.realpath(os.path.dirname(sys.argv[0]))
top_dir = os.path.dirname(script_dir)
sys.path.append(os.path.join(top_dir, 'py-lib'))
import project_defs
import tool_config

###########################################################################

# Default clock as a (period, uncertainty) pair.  It is used when neither
# the command line nor the configuration file supplies a clock.
# NOTE(review): the units are presumably nanoseconds, as passed to the
# Tcl create_clock command -- confirm against the tool documentation.
default_clock = (2.0, 0.5)
# Default FPGA part, used as the default value of the --part option.
default_part = 'xcvu9p-flga2104-2-e'

###########################################################################

# Name of the configuration file which is looked up inside the input
# directory.  Its presence is also how the first positional argument
# is recognised as an input directory.
_pipeline_json = "pipeline.json"

###########################################################################

# Valid values for the --clean option.  Each key maps to the list of
# further clean arguments which it implies; the argument parser expands
# these transitively.  After a batch run, 'sim' and 'csim' cause the
# solution1/sim and solution1/csim directories of a module to be removed.
_clean_defs = {
    'none': [],
    'csim': [],
    'sim': [],
    'most': ['csim','sim'],
}

###########################################################################

# Quote double-quotes, open-brackets, dollars and backslashes.  There is
# no need to quote close-brackets as they are only special after
# open-brackets.  Double-quotes must be quoted because the result is
# wrapped in double-quotes; an unquoted one would end the word early.
_tcl_quote_re = re.compile(r'(["\[$\\])')

def tcl_quote(o):
    """Return a Tcl representation of the object o.

    A tuple or list (or an instance of a subclass of either) becomes a
    brace-delimited, space-separated list of its recursively quoted
    elements.  Any other object is converted with str() and returned as
    a double-quoted word in which every double-quote, open-bracket,
    dollar and backslash is preceded by a backslash.
    """
    if isinstance(o, (tuple, list)):
        return "{%s}" % (" ".join(tcl_quote(x) for x in o))
    return ('"%s"' % _tcl_quote_re.sub(r'\\\1', str(o)))

###########################################################################

# Matches a run of one or more characters which are not ASCII letters
# or digits.  It is used to replace such runs with an underscore when
# building a preprocessor symbol from the tool name.
_special_re = re.compile(r'[^0-9a-zA-Z]+')

###########################################################################

# The build steps, as (name, immediate predecessors) pairs.  The order
# matters: every step is listed after its predecessors, which the step
# selection relies on when it walks the table forwards and backwards,
# and the scripts run the selected steps in this order.  Each name has
# a matching step_<name> method on tcl_writer.
step_defs = (
    ('reset', ()),
    ('c_sim', ('reset',)),
    ('synth', ('c_sim',)),
    ('cosim', ('synth',)),
    ('export', ('synth',)),
)

class tcl_writer(object):
    """Write the Tcl commands for building one module to a file object.

    The writer tracks whether the generated script has a project open
    so that each step can open the project on demand.  The module must
    be set with set_module() before any command which names it is
    written.
    """

    def __init__(self, fh):
        """Create a writer which emits commands to the file object fh."""
        # Start from the default clock so that step_reset() works even
        # when set_clock() has not been called.
        self.__clock_period, self.__clock_uncertainty = default_clock
        self.__fh = fh
        self.__project_open = False
        self.__part = default_part

    def set_part(self, part):
        """Set the FPGA part written by step_reset()."""
        self.__part = part

    def set_clock(self, period, uncertainty):
        """Set the clock period and uncertainty written by step_reset()."""
        self.__clock_period = period
        self.__clock_uncertainty = uncertainty

    def set_module(self, m):
        """Select the module to build.

        m is a dictionary with the keys 'subdir' (project name), 'top'
        (top-level function), 'sources' and 'testbench' (file lists).
        """
        self.__module_name = m['subdir']
        self.__module_top = m['top']
        self.__module_sources = m['sources']
        self.__module_testbench = m['testbench']

    def ensure_project_open(self):
        """Write the commands to open the project unless already open."""
        if self.__project_open:
            return

        self.__fh.write('open_project %s\n' %
                        (tcl_quote(self.__module_name),))
        self.__fh.write('open_solution solution1 -flow_target vitis\n')
        self.__project_open = True

    def close_project(self):
        """Write the command to close the project, opening it first if
        the script does not already have it open."""
        self.ensure_project_open()
        self.__fh.write('close_project\n')
        self.__project_open = False

    def show_banner(self, msg):
        """Write commands which print msg framed by lines of hashes."""
        fh = self.__fh
        msg_line = "### %s ###" % msg
        hashes = "#" * len(msg_line)
        fh.write("puts { }\n")
        fh.write("puts {%s}\n" % hashes)
        fh.write("puts {%s}\n" % msg_line)
        fh.write("puts {%s}\n" % hashes)

    def step_reset(self):
        """Write the commands which recreate the project from scratch:
        reset the project and solution, add the source and testbench
        files and set the part and clock."""
        self.__fh.write('open_project -reset %s\n' %
                        (tcl_quote(self.__module_name),))
        self.__fh.write('set_top %s\n' %
                        (tcl_quote(self.__module_top),))

        for name in self.__module_sources:
            self.__fh.write('add_files -cflags "$cflags" %s\n' %
                            (tcl_quote(name),))

        for name in self.__module_testbench:
            self.__fh.write('add_files -cflags "$cflags" -tb %s\n' %
                            (tcl_quote(name),))

        self.__fh.write('open_solution -reset solution1 -flow_target vitis\n')
        self.__fh.write('set_part {%s}\n' % (self.__part,))
        self.__fh.write('create_clock -name default -period %s\n' %
                        (self.__clock_period,))
        self.__fh.write('set_clock_uncertainty %s default\n' %
                        (self.__clock_uncertainty,))
        # The reset commands leave the project and solution open.
        self.__project_open = True

    def step_c_sim(self):
        """Write the C simulation command."""
        self.ensure_project_open()
        self.__fh.write('csim_design -argv "$module_args" -ldflags "$ldflags"\n')

    def step_synth(self):
        """Write the C synthesis command."""
        self.ensure_project_open()
        self.__fh.write('csynth_design\n')

    def step_cosim(self):
        """Write the co-simulation command."""
        self.ensure_project_open()
        self.__fh.write('cosim_design -argv "$module_args" -trace_level all'
                        ' -ldflags "$ldflags"\n')

    def step_export(self):
        """Write the design export command."""
        self.ensure_project_open()
        # The backslash-newline pairs below are Python line
        # continuations, so the command is written on a single line.
        self.__fh.write('''
export_design \
  -flow syn \
  -rtl verilog \
  -format xo \
  -vendor Xilinx \
  -library %s
''' % (tcl_quote(self.__module_name),))

###########################################################################

class dummy_arg(argparse.Action):
    """An argparse action which consumes no arguments and stores nothing.

    It is used for arguments which only need to be registered with the
    parser, without taking any values from the command line.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Force the argument count to zero, whatever the caller passed.
        self.nargs = 0

    def __call__(self, parser, namespace, values, opt_str):
        # Deliberately leave the namespace untouched.
        return None

###########################################################################

class app(object):
    def __init__(self):
        """Set up the state which exists before the arguments are parsed."""
        # No pipeline configuration has been loaded yet.
        self.__pipe_config = None
        # Outstanding child processes, indexed by PID.
        self.__processes = {}
        # Linker flags accumulated for the simulation steps.
        self.__ldflags = []

    def select_steps(self):
        """Work out which steps to run from the --steps arguments.

        The requested steps are run together with every step which lies
        between two requested steps.  A step lies between two others if
        it is an eventual successor of one and an eventual predecessor
        of the other.  The result is stored as a set.  Exits with an
        error if an unknown step was requested.
        """
        known = [name for name, _ in step_defs]

        # Split the comma separated values, ignoring empty entries.
        requested = {
            name
            for value in self.__args.steps
            for name in value.split(',')
            if name != ''
        }

        # Expand 'all' into every defined step.
        if 'all' in requested:
            requested.discard('all')
            requested.update(known)

        # Everything which is requested or is an eventual predecessor
        # of a requested step.  Walking the table backwards lets each
        # step pass the marking on to its own predecessors.
        at_or_before = set(requested)
        for name, preds in reversed(step_defs):
            if name in at_or_before:
                at_or_before.update(preds)

        # Everything which is requested or is an eventual successor of
        # a requested step.  Walking the table forwards lets each step
        # inherit the marking from any of its predecessors.
        at_or_after = set(requested)
        for name, preds in step_defs:
            if any(pred in at_or_after for pred in preds):
                at_or_after.add(name)

        # Requested names which are not in the step definitions.
        unknown = requested.difference(known)
        if len(unknown) != 0:
            sys.stderr.write("%s: Unknown steps %s\n" %
                             (sys.argv[0],
                              ", ".join(repr(s) for s in sorted(unknown))))
            sys.exit(1)

        # Run the steps which are in both sets.
        self.__steps = at_or_before & at_or_after

    #######################################################################

    def parse_args(self, argv):
        """Parse the command line argv and store the results.

        argv is the full argument vector including the program name.
        The positional arguments are an optional input directory
        (recognised by it containing the pipeline configuration file),
        the output directory and then any simulation arguments.

        Exits with status 1 after writing a message to stderr if the
        output directory is missing, a --clean value is unknown, no
        configuration can be located, the clock specification is
        invalid or a --set option is malformed.
        """
        parser = argparse.ArgumentParser(
            description='Build an HLS module',
        )

        parser.add_argument('--bsub', default=False, action='store_true',
                            help='Submit jobs using the bsub command.')
        parser.add_argument('-B', '--build-dir',
                            default=os.path.join(top_dir, "build"),
                            help='Set the Nanotube build directory.')
        parser.add_argument('--clean', default=[], action='append',
                            help='Clean various parts of the output directory.')
        parser.add_argument('-c', '--config', default=None,
                            help='Set the configuration file.')
        parser.add_argument('-C', '--clock', default=None, action='store',
                            help='Set the clock period and uncertainty.')
        parser.add_argument('-g', '--gui', default=False, action='store_true',
                            help='Start the GUI.')
        parser.add_argument('-j', '--jobs', action='store', type=int,
                            default=1,
                            help='Set the maximum number of parallel jobs.')
        parser.add_argument('-l', '--list-modules', default=False,
                            action='store_true',
                            help='List the modules.')
        parser.add_argument('-L', '--lib-dir', default=[], action='append',
                            help='Add a linker library directory.')
        parser.add_argument('-m', '--modules', default=[], action='append',
                            help='Select the modules to build.')
        parser.add_argument('-n', '--no-act', default=False,
                            action='store_true',
                            help='Do not run Vitis HLS.')
        parser.add_argument('--no-bsub-lib', default=False,
                            action='store_true',
                            help='Do not add external/bsub_lib when using --bsub.')
        parser.add_argument('-p', '--part', default=default_part,
                            action='store',
                            help='Set the FPGA part to target')
        parser.add_argument('-r', '--reset', default=False,
                            action='store_true',
                            help='Reset the build.')
        parser.add_argument('-S', '--set', default=[], action='append',
                            help='Set a tool option.')
        parser.add_argument('-s', '--steps', default=[], action='append',
                            help='Set the list of steps.')
        parser.add_argument('-t', '--tool', dest='tool', default='vitis_hls',
                            action='store',
                            help='Set the tool to run.')
        parser.add_argument('-v', '--verbose', default=False,
                            action='store_true',
                            help='Show verbose output.')
        parser.add_argument('--vitis-hls', dest='tool',
                            action='store_const', const='vitis_hls',
                            help='Compile using Vitis HLS.')

        # The two directory arguments are registered with an action
        # which consumes nothing.  All the positional values end up in
        # 'arguments' and are picked apart below.
        parser.add_argument('input_dir', action=dummy_arg,
                            metavar='[input_dir]',
                            help='Optional input directory.')
        parser.add_argument('output_dir', action=dummy_arg,
                            help='The output directory.')
        parser.add_argument('arguments', nargs='*',
                            help='Command line arguments for simulation.')

        args = parser.parse_args(argv[1:])

        args.script_dir = os.path.dirname(argv[0])

        # Work out the default steps.  Without --steps, --reset selects
        # just the reset step, --gui selects no steps and otherwise all
        # steps are run.
        if len(args.steps) == 0:
            if args.reset:
                args.steps = ['reset']
            elif args.gui:
                args.steps = []
            else:
                args.steps = ['all']
        elif args.reset:
            args.steps.append('reset')

        # The first positional argument is the input directory if it
        # contains a pipeline configuration file.
        if len(args.arguments) >= 1:
            candidate = os.path.join(args.arguments[0], _pipeline_json)
            if os.path.isfile(candidate):
                args.input_dir = args.arguments.pop(0)

        if len(args.arguments) < 1:
            sys.stderr.write("%s: Missing directory argument.\n" %
                             (argv[0],))
            sys.exit(1)

        # Strip trailing separators from the output directory.  Stop
        # when dirname() makes no progress, which happens for the root
        # directory and for an empty string, so the loop terminates.
        directory = args.arguments.pop(0)
        while os.path.basename(directory) == "":
            parent = os.path.dirname(directory)
            if parent == directory:
                break
            directory = parent

        # An empty directory name cannot be used as an output location.
        if directory == "":
            sys.stderr.write("%s: Missing directory argument.\n" %
                             (argv[0],))
            sys.exit(1)
        args.directory = directory

        # Expand the clean arguments transitively using _clean_defs.
        if not args.clean:
            args.clean = ['most']
        clean_args = [ y for x in args.clean for y in x.split(",") ]
        self.__clean = set(clean_args)
        while clean_args:
            arg = clean_args.pop()
            if arg not in _clean_defs:
                sys.stderr.write("%s: Unknown clean argument %r.\n" %
                                 (argv[0], arg))
                sys.exit(1)
            for x in _clean_defs[arg]:
                if x not in self.__clean:
                    self.__clean.add(x)
                    clean_args.append(x)

        if args.config is None and args.input_dir is None:
            sys.stderr.write("%s: No config files was"
                             " specified.\n" % argv[0])
            sys.exit(1)

        # Parse the clock as "period" or "period,uncertainty".
        if args.clock is None:
            clock = None
        else:
            try:
                clock = tuple(float(x) for x in args.clock.split(","))
            except ValueError as e:
                sys.stderr.write("%s: Invalid clock specification:"
                                 " %s.\n" % (argv[0], e))
                sys.exit(1)

            if len(clock) == 1:
                clock = clock + (0.5,)
            elif len(clock) != 2:
                sys.stderr.write("%s: Invalid clock specification.\n" %
                                 (argv[0],))
                sys.exit(1)

        bsub_lib = os.path.join(top_dir, "external/bsub_lib")
        if ( args.bsub and not args.no_bsub_lib and
             os.path.exists(bsub_lib) ):
            args.lib_dir.append(bsub_lib)

        for d in args.lib_dir:
            path = os.path.realpath(d)
            self.__ldflags.append("-L"+path)

        # Collect the VAR=VALUE tool options.
        tool_opts = {}
        for opt in args.set:
            parts = opt.split("=",1)
            if len(parts) != 2:
                sys.stderr.write("%s: Invalid tool option: %r\n" %
                                 (argv[0], opt))
                sys.exit(1)
            var,val = parts
            tool_opts[var] = val

        tc = tool_config.config(args.build_dir, tool_opts)

        # A tool name containing a slash is used as a path directly.
        # Otherwise it is looked up in the tool configuration.
        tool = args.tool
        if '/' in tool:
            tool_path = tool
        else:
            tool_path = tc.find_command(tool)['path']

        self.__part = args.part
        self.__args = args
        self.__clock = clock
        self.__tool_config = tc
        self.__tool_path = tool_path

    #######################################################################

    def read_config(self):
        """Read the JSON configuration and load the module definitions.

        The file is the one given with --config or, failing that, the
        pipeline configuration file in the input directory.  A file
        containing 'module_top' is a module configuration.  A file
        containing both 'channels' and 'stages' is a pipeline
        configuration.  Exits with status 1 if the file is neither or
        is both.  The default clock is applied if no clock was set by
        the command line or the configuration.
        """
        filename = self.__args.config
        if filename is None:
            filename = os.path.join(self.__args.input_dir, _pipeline_json)
        self.__src_dir = os.path.dirname(filename)

        # Close the file as soon as it has been parsed.
        with open(filename, "r") as f:
            t = json.load(f)

        is_module = ('module_top' in t)
        is_pipeline = ('channels' in t and 'stages' in t)

        succ_count = int(is_pipeline) + int(is_module)
        if succ_count  < 1:
            sys.stderr.write("%s: Could not recognise the JSON structure.\n" %
                             (sys.argv[0],))
            sys.exit(1)
        elif succ_count > 1:
            sys.stderr.write("%s: The JSON structure is ambiguous.\n" %
                             (sys.argv[0],))
            sys.exit(1)

        if is_module:
            self.load_module_config(t)
        elif is_pipeline:
            self.load_pipeline_config(t)

        if self.__clock is None:
            self.__clock = default_clock

    def load_module_config(self, t):
        """Load the configuration of a single module from the dict t.

        The keys 'module_top', 'module_sources' and 'module_testbench'
        are required.  The keys 'module_clock' and 'module_args' are
        optional.  The clock from the configuration is only used when
        no clock was given on the command line.  The tool is run from
        the parent of the output directory and the last component of
        the output directory is the project name.
        """
        if self.__clock is None:
            # Leave the clock unset if the configuration does not have
            # one either, so that the default clock gets applied later.
            clock = t.get("module_clock")
            if clock is not None:
                self.__clock = tuple(clock)
        b,d = os.path.split(self.__args.directory)
        if d == "":
            d = "."
        self.__csim_args = t.get('module_args', ())
        self.__run_dir = b

        # The file names in the configuration are relative to the
        # directory holding the configuration file.  Make them relative
        # to the directory the tool is run from.
        rel_src_dir = os.path.relpath(self.__src_dir, b)

        self.__testbench = [ os.path.join(rel_src_dir, name)
                             for name in t['module_testbench'] ]

        sources = [ os.path.join(rel_src_dir, name)
                    for name in t['module_sources'] ]

        self.__modules = [
            {
                'sources': sources,
                'subdir': d,
                'top': t['module_top'],
            },
        ]

    def add_test_harness(self):
        """Add the test harness dependencies to the build.

        Each dependency listed for 'test_harness' in the project
        definitions is either a file or a library.  The object file for
        the test harness main program is replaced by its source file,
        which is added to the testbench.  Any other file is added to
        the linker flags as a path in the build directory.  A library
        is added to the linker flags as a -l option.
        """
        script_parent = os.path.join(os.path.dirname(sys.argv[0]), "..")
        top_dir = os.path.realpath(script_parent)
        build_dir = os.path.realpath(self.__args.build_dir)

        for kind, name in project_defs.lib_defs['test_harness']:
            if kind != 'file':
                assert(kind == 'lib')
                self.__ldflags.append('-l'+name)
                continue

            if os.path.basename(name) != 'test_harness_main.o':
                self.__ldflags.append(os.path.join(build_dir, name))
                continue

            # Compile the main program from source with the testbench.
            main_src = "testing/harness/test_harness_main.cpp"
            self.__testbench.append(os.path.join(top_dir, main_src))

    def load_pipeline_config(self, t):
        """Load the configuration of a pipeline from the dict t.

        One module called stage_<thread_id> is created for each entry
        in t['stages'], with a single source file of the same name in
        the source directory.  The tool is run from the output
        directory.  The testbench is the poll thread source plus the
        test harness.
        """
        run_dir = self.__args.directory
        self.__pipe_config = t
        self.__csim_args = []
        self.__run_dir = run_dir

        # Source files are named relative to the run directory.
        rel_src_dir = os.path.relpath(self.__src_dir, run_dir)
        self.__testbench = [ os.path.join(rel_src_dir, 'poll_thread.cc') ]

        stage_names = ( "stage_" + str(stage['thread_id'])
                        for stage in t['stages'] )
        self.__modules = [
            {
                'sources': [os.path.join(rel_src_dir, stage_name+".cc")],
                'subdir': stage_name,
                'top': stage_name,
            }
            for stage_name in stage_names
        ]

        self.add_test_harness()

    #######################################################################

    def select_modules(self):
        """Attach a testbench to each module and apply --modules.

        The testbench of a module is the common testbench followed by
        the sources of all the other modules.  Each --modules value is
        a comma separated list of specs.  A spec is a module or a range
        FIRST-LAST, where a module is given as an index or as the name
        of its top-level function.  The selected modules are kept in
        their original order.  Exits with status 1 on an invalid spec.
        """
        module_index = {}
        for i,m in enumerate(self.__modules):
            # Insert the module into the index.
            module_index[m['top']] = i

            # Create the testbench list.  This includes the source
            # files from all the other modules.
            testbench = list(self.__testbench)
            for other in (self.__modules[:i] + self.__modules[i+1:]):
                testbench.extend(other['sources'])
            m['testbench'] = testbench

        # Do not remove any modules by default.
        if len(self.__args.modules) == 0:
            return

        # Parse the modules command line argument.
        module_specs = [ y for x in self.__args.modules
                         for y in x.split(',') ]
        module_set = set()
        for spec in module_specs:
            # Translate names into indices and leave the rest alone.
            spec_parts = tuple(module_index.get(x, x)
                               for x in spec.split('-', 1))
            spec = "-".join(str(x) for x in spec_parts)
            try:
                spec_parts = tuple(int(x) for x in spec_parts)
            except ValueError:
                # Only a failed conversion is an invalid spec.  Other
                # exceptions, such as an interrupt, must propagate.
                sys.stderr.write("%s: Invalid module spec '%s'.\n" %
                                 (sys.argv[0], spec))
                sys.exit(1)

            # A single module is a range with the same first and last.
            if ( spec_parts[0] < 0 or
                 spec_parts[0] > spec_parts[-1] or
                 spec_parts[-1] >= len(self.__modules) ):
                sys.stderr.write("%s: Module number out of range: %s.\n" %
                                 (sys.argv[0], spec))
                sys.exit(1)

            module_set.update(range(spec_parts[0], spec_parts[-1]+1))

        selected = tuple(self.__modules[i] for i in sorted(module_set))
        self.__modules = selected

    #######################################################################

    def list_modules(self):
        """Print a heading followed by the top-level name of each
        module, one per line and indented by two spaces."""
        print("Modules:")
        for entry in self.__modules:
            top_name = entry['top']
            print("  "+top_name)

    #######################################################################

    def verify(self):
        """Check that the options are consistent with the modules.

        The GUI can only be opened on a single module, so exit with
        status 1 if --gui was given and the number of modules is not
        exactly one.
        """
        # Nothing to check unless the GUI was requested.
        if not self.__args.gui:
            return

        if len(self.__modules) == 1:
            return

        sys.stderr.write('%s: There are %d modules, not 1 with --gui'
                         ' specified.\n' %
                         (sys.argv[0], len(self.__modules)))
        sys.exit(1)

    #######################################################################

    def create_directory(self):
        """Create the output directory and save the pipeline config.

        The directory is created along with any missing parents.  If a
        pipeline configuration was loaded, it is written to
        nanotube_hls_build.json in the output directory.
        """
        proj_dir = self.__args.directory
        # Tolerate the directory already existing, including the case
        # where another process creates it at the same time.
        os.makedirs(proj_dir, exist_ok=True)
        if self.__pipe_config is not None:
            # Write the config file.  Write to a temporary name which
            # includes the PID and then rename it into place, so that
            # the final file is never seen partially written.
            basename = "nanotube_hls_build.json"
            filename = os.path.join(proj_dir, basename)
            pid = os.getpid()
            tmpname = filename + "." + str(pid)
            with open(tmpname, "w") as fh:
                json.dump(self.__pipe_config, fh)
            os.replace(tmpname, filename)

    #######################################################################

    def get_cflags(self):
        """Return the compiler flags as a space separated string.

        The flags add the include directory next to the directory of
        this script, relative to the run directory, and define the
        symbol NANOTUBE_USING_<TOOL>, which is formed from the upper
        case tool name with each run of non-alphanumeric characters
        replaced by an underscore.
        """
        tool = self.__args.tool

        script_dir = os.path.dirname(sys.argv[0])
        inc_dir = os.path.realpath(os.path.join(script_dir, "../include"))
        include_flag = "-I"+os.path.relpath(inc_dir, self.__run_dir)

        define_flag = "-D"+_special_re.sub('_', "NANOTUBE_USING_"+tool.upper())

        flags = [
            include_flag,
            "-std=c++14",
            "-fexceptions",
            define_flag,
            "-Wall",
            "-Wno-unknown-pragmas",
            "-Wno-sign-compare",
            "-Wno-unused-variable",
            "-Wno-cpp",
            "-Wno-address",
        ]

        if tool == 'vitis_hls':
            flags.append("-Wno-unused-label")
        else:
            # Do not use -Werror with Vitis HLS due to CR-1128691
            flags.append("-Werror")

        return " ".join(flags)

    def create_script(self, module):
        """Write the Tcl script which builds the given module.

        The script is called <top>.tcl in the output directory.  It
        sets the cflags, ldflags and module_args variables, runs each
        selected step in the order given by the step definitions,
        closes the project and quits.
        """
        mod_name = module['top']
        filename = os.path.join(self.__args.directory, mod_name+".tcl")
        # The simulation arguments from the configuration come before
        # those from the command line.
        arguments = " ".join(itertools.chain(self.__csim_args,
                                             self.__args.arguments))
        cflags = self.get_cflags()

        # Use a context manager so that the file is closed even if
        # writing one of the steps raises an exception.
        with open(filename, "w") as fh:
            writer = tcl_writer(fh)

            fh.write("set cflags %s\n" % (tcl_quote(cflags),))
            fh.write("set ldflags %s\n" % (tcl_quote(self.__ldflags),))
            fh.write("set module_args %s\n" % (tcl_quote(arguments)))
            writer.set_clock(self.__clock[0], self.__clock[1])
            writer.set_part(self.__part)

            writer.set_module(module)

            for step, _ in step_defs:
                if step in self.__steps:
                    writer.show_banner("Running step '%s' on '%s'." %
                                       (step, module['top']))
                    # Each step has a step_<name> method on the writer.
                    method = getattr(writer, "step_"+step)
                    method()
                    fh.write("\n")

            writer.close_project()

            writer.show_banner("Exiting.")
            fh.write("quit\n")

    def create_scripts(self):
        """Write the Tcl script of every selected module."""
        selected = self.__modules
        for entry in selected:
            self.create_script(entry)

    #######################################################################

    def wait_one_proc(self):
        """Wait for one child process to terminate and report it.

        The process is removed from the table of outstanding processes
        and its callback, if any, is called with the callback argument
        and a result code.  The code is the exit status for a normal
        exit or the negated signal number if the process was killed by
        a signal.  There must be at least one outstanding process.
        """
        assert len(self.__processes) != 0
        while True:
            # Wait for a child to exit.
            pid,rc = os.wait()
            proc,desc,callback,callback_arg = self.__processes[pid]

            # Anything other than a stop notification means the child
            # has terminated, either normally or due to a signal.
            if not os.WIFSTOPPED(rc):
                break

            # Report a child being stopped.  It's not clear why this
            # would happen.  The child is still running, so go back to
            # waiting.
            sig = os.WSTOPSIG(rc)
            sys.stderr.write("%s: %s stopped with signal %d.\n" %
                             (sys.argv[0], desc, sig))

        # Remove it from the outstanding processes.
        del self.__processes[pid]

        # Make sure the subprocess object performs its wait before any
        # other children are forked.  This is to avoid the slim
        # possibility of the PID being reused for another child which
        # could cause the Popen object to wait for the wrong child.
        dummy = proc.wait()

        if os.WIFEXITED(rc):
            code = os.WEXITSTATUS(rc)
            if code == 0:
                sys.stderr.write("Completed %s.\n" % (desc,))
            else:
                sys.stderr.write("%s: %s returned exit code %d.\n" %
                                 (sys.argv[0], desc, code))

            if callback is not None:
                callback(callback_arg, code)

        else:
            assert os.WIFSIGNALED(rc)
            sig = os.WTERMSIG(rc)
            sys.stderr.write("%s: %s exited with signal %d.\n" %
                             (sys.argv[0], desc, sig))

            if callback is not None:
                callback(callback_arg, -sig)

    def wait_all_procs(self):
        """Wait until there are no outstanding child processes."""
        # Each call removes one entry from the table.
        while self.__processes:
            self.wait_one_proc()

    def run_command(self, run_dir, arg_list, desc, output=None,
                    callback=None, callback_arg=None):
        """Start a command in the background.

        Waits for outstanding processes to finish first if the job
        limit has been reached.  Nothing is started if --no-act was
        given.

        run_dir      -- the working directory of the command.
        arg_list     -- the command and its arguments.
        desc         -- a description used in progress messages.
        output       -- an optional file name which receives the
                        standard output and standard error.
        callback     -- an optional function which is called with
                        callback_arg and the result code when the
                        command finishes.
        callback_arg -- the first argument to the callback.
        """
        # Always allow at least one job.
        while len(self.__processes) >= max(1,self.__args.jobs):
            self.wait_one_proc()

        sys.stderr.write("Running %s in %r.\n" % (desc, run_dir))
        if self.__args.verbose:
            sys.stderr.write("Command: %s\n" %
                             (" ".join(shlex.quote(x) for x in arg_list)))

        if self.__args.no_act:
            return

        stdout = None
        if output is not None:
            stdout = open(output, "w")
        try:
            proc = subprocess.Popen(arg_list, cwd=run_dir,
                                    stdin=subprocess.DEVNULL, stdout=stdout,
                                    stderr=subprocess.STDOUT)
        finally:
            # The child has its own copy of the file descriptor, so
            # close this one whether or not the child was started.
            if stdout is not None:
                stdout.close()

        assert proc.pid not in self.__processes
        self.__processes[proc.pid] = (proc, desc, callback, callback_arg)

    def done_hls_batch(self, name, code):
        """Clean up after the batch run of the module called name.

        The sim and csim directories of the solution are removed if
        the corresponding clean options are set.  Errors during the
        removal are ignored.  The result code is not used.
        """
        module_dir = os.path.join(self.__args.directory, name)
        # Handle sim before csim, removing each only when requested.
        for kind in ("sim", "csim"):
            if kind not in self.__clean:
                continue
            target = os.path.join(module_dir, "solution1/"+kind)
            shutil.rmtree(target, ignore_errors=True)

    def run_hls_batch(self):
        """Run the Tcl script of every module and wait for them all.

        Each module is run in the run directory with its output going
        to <top>.log in the output directory.  The command is prefixed
        with bsub and its configured options if --bsub was given.  The
        clean up for each module is run as it completes.
        """
        run_dir = self.__run_dir
        proj_dir = self.__args.directory
        # The scripts are named relative to the run directory.
        rel_proj_dir = os.path.relpath(proj_dir, run_dir)
        bsub_option_names = (
            "BSUB_CORE_OPTIONS",
            "BSUB_BATCH_OPTIONS",
            "BSUB_EXTRA_OPTIONS",
        )

        for entry in self.__modules:
            top_name = entry['top']

            command = []
            if self.__args.bsub:
                command.append('bsub')
                for option_name in bsub_option_names:
                    options = self.__tool_config.get(option_name)
                    command.extend(shlex.split(options))

            command.append(self.__tool_path)
            command.extend(['-f', os.path.join(rel_proj_dir, top_name+'.tcl')])
            command.extend(['-l', '/dev/null'])

            self.run_command(run_dir, command, "synthesis of "+top_name,
                             output=os.path.join(proj_dir, top_name+'.log'),
                             callback=self.done_hls_batch,
                             callback_arg=top_name)

        self.wait_all_procs()

    def run_hls_gui(self, module_dir, log_file):
        """Open the project in module_dir with the tool and wait for
        the tool to exit.

        The tool is run from the run directory and is told to log to
        log_file.  The command is prefixed with bsub and its configured
        options if --bsub was given.
        """
        command = []
        if self.__args.bsub:
            config = self.__tool_config
            command.append('bsub')
            for option_name in ("BSUB_CORE_OPTIONS",
                                "BSUB_XWIN_OPTIONS",
                                "BSUB_EXTRA_OPTIONS"):
                command.extend(shlex.split(config.get(option_name)))

        command += [self.__tool_path, '-l', log_file, '-p', module_dir]

        self.run_command(self.__run_dir, command, "HLS gui")
        self.wait_all_procs()

    def run_hls(self):
        """Run the tool in batch mode and then in GUI mode, as needed.

        A random 32 bit seed is printed and exported to the child
        processes in the TEST_RAND_SEED environment variable.  The
        batch run is skipped if no steps were selected and the GUI is
        only opened if --gui was given.
        """
        seed = random.getrandbits(32)
        print("Random seed is 0x%08x" % (seed,))
        os.environ['TEST_RAND_SEED'] = str(seed)

        # Run the script using HLS if necessary.
        if self.__steps:
            self.run_hls_batch()

        # Open the GUI if necessary.
        if not self.__args.gui:
            return

        gui_log = os.path.join(self.__args.directory, 'hls_gui.log')
        self.run_hls_gui(self.__modules[0]['subdir'], gui_log)

    #######################################################################

    def run(self, argv):
        """Run the application with the argument vector argv.

        Returns the exit status, which is always zero since failures
        exit directly.  With --list-modules the modules are listed and
        nothing is built.
        """
        self.parse_args(argv)

        # Work out what to build.
        for prepare in (self.read_config,
                        self.select_steps,
                        self.select_modules):
            prepare()

        if self.__args.list_modules:
            self.list_modules()
        else:
            # Check the request, generate the scripts and run them.
            for build in (self.verify,
                          self.create_directory,
                          self.create_scripts,
                          self.run_hls):
                build()

        return 0

# Only run the application when this file is executed as a script, so
# that loading it as a module does not run the tool and exit.
if __name__ == '__main__':
    sys.exit(app().run(sys.argv))

###########################################################################
